import pandas as pd
from json import load
import plotly.express as px
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn import metrics
import plotly.express as px
# Read the combined April/May COVIDiSTRESS survey export.
df = pd.read_csv("COVIDiSTRESS_April_May_Combined.csv")

# Derive integer join keys from the coordinates: scale by 1000 and truncate
# toward zero (e.g. 0.622222 -> 622). The displayed head suggests these
# columns hold normalised fractions rather than degrees -- TODO confirm.
for axis in ("latitude", "longitude"):
    df[axis + "x"] = (df[axis] * 1000).apply(int)
df.head()
| Unnamed: 0 | PSS10_avg | latitude | longitude | Trust_countrymeasure | Lon_avg | Dem_age | Dem_edu | Dem_Expat | Dem_dependents | ... | Married/cohabiting | Single | Uninformative response | 1 | Isolated | Isolated in medical facility of similar location | Life carries on as usual | Life carries on with minor changes | latitudex | longitudex | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0.32 | 0.622222 | 0.055556 | 0.5 | 0.466667 | -0.416931 | 0.833333 | 1.0 | 0.000000 | ... | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 622 | 55 |
| 1 | 1 | 0.48 | 0.540741 | 0.108333 | 0.5 | 0.733333 | 0.989318 | 0.833333 | 0.0 | 0.000000 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 540 | 108 |
| 2 | 2 | 0.62 | 0.540741 | 0.108333 | 0.5 | 0.533333 | 0.619253 | 0.666667 | 0.0 | 0.582783 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 540 | 108 |
| 3 | 3 | 0.38 | 0.511111 | 0.011111 | 0.2 | 0.266667 | 0.545240 | 0.833333 | 0.0 | 0.321513 | ... | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 511 | 11 |
| 4 | 4 | 0.54 | 0.540741 | 0.108333 | 0.8 | 0.466667 | 1.581423 | 1.000000 | 0.0 | 0.000000 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 540 | 108 |
5 rows × 60 columns
# Load the country-codes JSON used to translate coordinates into a country.
# JSON text is UTF-8 by specification, so the encoding is stated explicitly
# instead of relying on the platform default (which breaks on non-ASCII
# country names under non-UTF-8 locales).
with open("country-codes-lat-long-alpha3.json", "r", encoding="utf-8") as f:
    c_j = load(f)

# Keep only the columns needed for the lookup; `from_dict` is a classmethod,
# so it is called on the class rather than on a throwaway empty instance.
country_df = pd.DataFrame.from_dict(c_j["ref_country_codes"]).loc[:, ["alpha3", "latitude", "longitude", "country"]]

# Build the same integer keys as the survey frame: normalise degrees by
# 90 / 180, scale by 1000 and truncate toward zero.
country_df["latitudex"] = (country_df["latitude"] / 90.0 * 1000).apply(int)
country_df["longitudex"] = (country_df["longitude"] / 180.0 * 1000).apply(int)

# The raw degree columns are no longer needed once the keys exist.
country_df = country_df.drop(columns=["latitude", "longitude"])
country_df
| alpha3 | country | latitudex | longitudex | |
|---|---|---|---|---|
| 0 | ALB | Albania | 455 | 111 |
| 1 | DZA | Algeria | 311 | 16 |
| 2 | ASM | American Samoa | -159 | -944 |
| 3 | AND | Andorra | 472 | 8 |
| 4 | AGO | Angola | -138 | 102 |
| ... | ... | ... | ... | ... |
| 242 | AFG | Afghanistan | 366 | 361 |
| 243 | Kosovo | 473 | 116 | |
| 244 | Laos | 221 | 569 | |
| 245 | Sudan, South | 76 | 173 | |
| 246 | other | 0 | 0 |
247 rows × 4 columns
# Attach `alpha3` and `country` to every survey row by matching on the integer
# coordinate keys, then discard the keys (they only exist for this join).
# NOTE(review): a left merge repeats a survey row for every lookup row that
# shares its key pair -- presumably the keys are unique per country; verify.
coord_keys = ["latitudex", "longitudex"]
df = df.merge(country_df, on=coord_keys, how="left").drop(columns=coord_keys)
df
| Unnamed: 0 | PSS10_avg | latitude | longitude | Trust_countrymeasure | Lon_avg | Dem_age | Dem_edu | Dem_Expat | Dem_dependents | ... | Married/cohabiting | Single | Uninformative response | 1 | Isolated | Isolated in medical facility of similar location | Life carries on as usual | Life carries on with minor changes | alpha3 | country | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0.32 | 0.622222 | 0.055556 | 0.5 | 0.466667 | -0.416931 | 0.833333 | 1.0 | 0.000000 | ... | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | DNK | Denmark |
| 1 | 1 | 0.48 | 0.540741 | 0.108333 | 0.5 | 0.733333 | 0.989318 | 0.833333 | 0.0 | 0.000000 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | SVK | Slovakia |
| 2 | 2 | 0.62 | 0.540741 | 0.108333 | 0.5 | 0.533333 | 0.619253 | 0.666667 | 0.0 | 0.582783 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | SVK | Slovakia |
| 3 | 3 | 0.38 | 0.511111 | 0.011111 | 0.2 | 0.266667 | 0.545240 | 0.833333 | 0.0 | 0.321513 | ... | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | FRA | France |
| 4 | 4 | 0.54 | 0.540741 | 0.108333 | 0.8 | 0.466667 | 1.581423 | 1.000000 | 0.0 | 0.000000 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | SVK | Slovakia |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 173009 | 89830 | 0.34 | 0.622222 | 0.055556 | 0.6 | 0.466667 | -0.786996 | 0.833333 | 1.0 | 0.000000 | ... | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | DNK | Denmark |
| 173010 | 89831 | 0.60 | 0.622222 | 0.055556 | 0.4 | 0.466667 | -0.416931 | 0.666667 | 1.0 | 0.000000 | ... | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | DNK | Denmark |
| 173011 | 89832 | 0.28 | 0.622222 | 0.055556 | 0.4 | 0.266667 | 1.359384 | 0.500000 | 1.0 | 0.321513 | ... | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | DNK | Denmark |
| 173012 | 89833 | 0.38 | 0.622222 | 0.055556 | 0.8 | 0.333333 | 1.137345 | 0.666667 | 1.0 | 0.000000 | ... | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | DNK | Denmark |
| 173013 | 89834 | 0.72 | 0.622222 | 0.055556 | 1.0 | 0.866667 | 1.359384 | 0.500000 | 1.0 | 0.000000 | ... | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | DNK | Denmark |
173014 rows × 60 columns
# Average every numeric column per country to get the baseline (observed)
# stress. `numeric_only=True` keeps the string column `alpha3` out of the
# mean: pandas >= 2.0 raises TypeError on it instead of silently dropping it.
# `.iloc[1:, :]` skips the first group in sorted country order, as before.
# NOTE(review): which country that first group is cannot be seen here --
# confirm that dropping it is intentional.
mean_stress = df.groupby("country").mean(numeric_only=True).reset_index().iloc[1:, :]
mean_stress
| country | Unnamed: 0 | PSS10_avg | latitude | longitude | Trust_countrymeasure | Lon_avg | Dem_age | Dem_edu | Dem_Expat | ... | Student | Divorced/widowed | Married/cohabiting | Single | Uninformative response | 1 | Isolated | Isolated in medical facility of similar location | Life carries on as usual | Life carries on with minor changes | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | Albania | 36826.571429 | 0.580514 | 0.455556 | 0.111111 | 0.619465 | 0.510099 | -0.568481 | 0.730159 | 0.142857 | ... | 0.142857 | 0.047619 | 0.285714 | 0.619048 | 0.0 | 0.0 | 0.714286 | 0.000000 | 0.000000 | 0.285714 |
| 2 | Algeria | 54860.897436 | 0.601538 | 0.311111 | 0.016667 | 0.279487 | 0.548718 | -0.365691 | 0.820513 | 0.282051 | ... | 0.179487 | 0.051282 | 0.589744 | 0.358974 | 0.0 | 0.0 | 0.435897 | 0.000000 | 0.000000 | 0.564103 |
| 3 | Andorra | 37004.411765 | 0.496471 | 0.472222 | 0.008889 | 0.471104 | 0.466667 | -0.347271 | 0.568627 | 0.470588 | ... | 0.000000 | 0.117647 | 0.294118 | 0.352941 | 0.0 | 0.0 | 0.294118 | 0.000000 | 0.352941 | 0.352941 |
| 4 | Angola | 55075.500000 | 0.520000 | -0.138889 | 0.102778 | 0.600000 | 0.633333 | -0.712983 | 0.916667 | 0.500000 | ... | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 0.0 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 1.000000 |
| 5 | Antigua and Barbuda | 24005.333333 | 0.462222 | 0.189444 | -0.343333 | 0.166667 | 0.377778 | 0.495898 | 0.722222 | 0.000000 | ... | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 0.0 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 1.000000 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 154 | Venezuela | 58180.687500 | 0.553750 | 0.088889 | -0.366667 | 0.387500 | 0.533333 | 0.226058 | 0.760417 | 0.000000 | ... | 0.125000 | 0.187500 | 0.125000 | 0.687500 | 0.0 | 0.0 | 0.812500 | 0.000000 | 0.000000 | 0.187500 |
| 155 | Vietnam | 61977.884615 | 0.514051 | 0.177778 | 0.588889 | 0.542308 | 0.457436 | -1.028393 | 0.746154 | 0.069231 | ... | 0.361538 | 0.015385 | 0.230769 | 0.700000 | 0.0 | 0.0 | 0.200000 | 0.030769 | 0.084615 | 0.684615 |
| 156 | Zambia | 38966.000000 | 0.400000 | -0.166667 | 0.166667 | 0.200000 | 0.466667 | 1.211358 | 0.833333 | 1.000000 | ... | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 0.0 | 0.0 | 0.500000 | 0.000000 | 0.000000 | 0.500000 |
| 157 | Zimbabwe | 49158.000000 | 0.560000 | -0.222222 | 0.166667 | 0.500000 | 0.266667 | -1.157062 | 0.833333 | 0.000000 | ... | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 0.0 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 1.000000 |
| 158 | other | 38660.050420 | 0.530556 | 0.000000 | 0.000000 | 0.426198 | 0.502801 | 0.177040 | 0.698179 | 0.184874 | ... | 0.151261 | 0.117647 | 0.474790 | 0.365546 | 0.0 | 0.0 | 0.340336 | 0.000000 | 0.147059 | 0.512605 |
158 rows × 59 columns
# Plot the baseline: one value per country, taken from the per-country means.
# The per-respondent frame `df` has many rows per country, so passing it here
# would not show the country average that the predicted-stress map shows.
fig = px.choropleth(
    mean_stress,
    locations="country",
    locationmode='country names',
    color="PSS10_avg",
    color_continuous_scale=px.colors.sequential.YlOrRd,
)
fig.update_layout(
    geo=dict(
        showframe=False,
        showcoastlines=False,
        projection_type='equirectangular'
    ))
fig.show()
# Remove the columns that are not used from here on.
df = df.drop(columns=["Unnamed: 0", "1", "Uninformative response"])

# Model input: every remaining column except the target (`PSS10_avg`) and the
# two country label columns, as a plain NumPy array.
x = df.drop(columns=["PSS10_avg", "country", "alpha3"]).to_numpy()
class StressNN(nn.Module):
    """Fully connected regressor mapping a feature vector to one value.

    The network is a stack of ``Linear -> LeakyReLU`` pairs followed by a
    final ``Linear`` layer with a single output. No activation is applied to
    that last layer, so the output is an unbounded real number.

    Args:
        input_size: Number of input features per sample.
        hidden_sizes: Output width of each hidden ``Linear`` layer, in order.
            The default reproduces the original fixed architecture
            (256-128-64-32-16), so state-dict keys (``LinBlock.0.weight``,
            ``LinBlock.2.weight``, ...) and shapes are unchanged.
    """

    def __init__(self, input_size=54, hidden_sizes=(256, 128, 64, 32, 16)):
        super().__init__()
        layers = []
        in_features = input_size
        for width in hidden_sizes:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.LeakyReLU())
            in_features = width
        # Regression head: a single output unit.
        layers.append(nn.Linear(in_features, 1))
        # The attribute name is part of the checkpoint key prefix; keep it.
        self.LinBlock = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for ``x`` (last dimension becomes 1)."""
        return self.LinBlock(x)
# Preferred device. NOTE(review): nothing in the visible code moves the model
# or the inputs to DEVICE; the weights below are mapped to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Size the input layer from the feature matrix so it matches the data.
n_features = x.shape[1]
model = StressNN(n_features)

# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load("model_april_may_mse.pt", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)
<All keys matched successfully>
# Put the model into evaluation mode before running inference.
model.eval()
StressNN(
(LinBlock): Sequential(
(0): Linear(in_features=54, out_features=256, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=256, out_features=128, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=128, out_features=64, bias=True)
(5): LeakyReLU(negative_slope=0.01)
(6): Linear(in_features=64, out_features=32, bias=True)
(7): LeakyReLU(negative_slope=0.01)
(8): Linear(in_features=32, out_features=16, bias=True)
(9): LeakyReLU(negative_slope=0.01)
(10): Linear(in_features=16, out_features=1, bias=True)
)
)
# Predict stress for every individual using the model.
# Inference only: disable autograd so no computation graph is built and kept
# alive for the whole dataset.
with torch.no_grad():
    outputs = model(torch.from_numpy(x).float())
# `out` keeps one column per sample (the model has a single output unit).
out = outputs.detach().numpy()
df['out'] = out
# Average every numeric column per country, including the model output `out`,
# to get the mean predicted stress. `numeric_only=True` keeps the string
# column `alpha3` out of the mean: pandas >= 2.0 raises TypeError on it
# instead of silently dropping it.
# `.iloc[1:, :]` skips the first group in sorted country order, as before.
# NOTE(review): confirm that dropping that first group is intentional.
preds = df.groupby("country").mean(numeric_only=True).reset_index().iloc[1:, :]
preds.head()
| country | PSS10_avg | latitude | longitude | Trust_countrymeasure | Lon_avg | Dem_age | Dem_edu | Dem_Expat | Dem_dependents | ... | Self-employed | Student | Divorced/widowed | Married/cohabiting | Single | Isolated | Isolated in medical facility of similar location | Life carries on as usual | Life carries on with minor changes | out | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | Albania | 0.580514 | 0.455556 | 0.111111 | 0.619465 | 0.510099 | -0.568481 | 0.730159 | 0.142857 | 0.289952 | ... | 0.285714 | 0.142857 | 0.047619 | 0.285714 | 0.619048 | 0.714286 | 0.0 | 0.000000 | 0.285714 | 0.553827 |
| 2 | Algeria | 0.601538 | 0.311111 | 0.016667 | 0.279487 | 0.548718 | -0.365691 | 0.820513 | 0.282051 | 0.271046 | ... | 0.179487 | 0.179487 | 0.051282 | 0.589744 | 0.358974 | 0.435897 | 0.0 | 0.000000 | 0.564103 | 0.571087 |
| 3 | Andorra | 0.496471 | 0.472222 | 0.008889 | 0.471104 | 0.466667 | -0.347271 | 0.568627 | 0.470588 | 0.294693 | ... | 0.176471 | 0.000000 | 0.117647 | 0.294118 | 0.352941 | 0.294118 | 0.0 | 0.352941 | 0.352941 | 0.499212 |
| 4 | Angola | 0.520000 | -0.138889 | 0.102778 | 0.600000 | 0.633333 | -0.712983 | 0.916667 | 0.500000 | 0.160756 | ... | 0.500000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 1.000000 | 0.618941 |
| 5 | Antigua and Barbuda | 0.462222 | 0.189444 | -0.343333 | 0.166667 | 0.377778 | 0.495898 | 0.722222 | 0.000000 | 0.214342 | ... | 0.000000 | 0.000000 | 0.000000 | 1.000000 | 0.000000 | 0.000000 | 0.0 | 0.000000 | 1.000000 | 0.456178 |
5 rows × 57 columns
# World map of the mean predicted stress (`out`) per country.
geo_options = {
    "showframe": False,
    "showcoastlines": False,
    "projection_type": 'equirectangular',
}
fig = px.choropleth(
    data_frame=preds,
    locations="country",
    locationmode='country names',
    color="out",
    color_continuous_scale=px.colors.sequential.YlOrRd,
)
fig.update_layout(geo=geo_options)
fig.show()

# Show the first 50 raw per-respondent predictions.
out[:50]
array([[0.39628884],
[0.60251015],
[0.593424 ],
[0.35520253],
[0.48110136],
[0.50212055],
[0.6931589 ],
[0.6849951 ],
[0.5271847 ],
[0.47137263],
[0.40228146],
[0.5061851 ],
[0.62119526],
[0.53120637],
[0.535346 ],
[0.45351908],
[0.5518916 ],
[0.65685797],
[0.5931185 ],
[0.49528977],
[0.29172155],
[0.47207937],
[0.5555323 ],
[0.44992605],
[0.4077221 ],
[0.37738153],
[0.39731947],
[0.40041593],
[0.52745354],
[0.45373526],
[0.47814596],
[0.48705766],
[0.52141714],
[0.47656056],
[0.479418 ],
[0.4147366 ],
[0.35151276],
[0.65657574],
[0.38083208],
[0.5246014 ],
[0.73703784],
[0.5028131 ],
[0.5506441 ],
[0.65674025],
[0.74531347],
[0.60975367],
[0.547792 ],
[0.4736585 ],
[0.6052245 ],
[0.5773858 ]], dtype=float32)
# Distribution of the per-respondent predictions.
px.histogram(data_frame=df, x="out")